# coding=utf-8

# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.

# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.

#

# Licensed under the Apache License, Version 2.0 (the "License");

# you may not use this file except in compliance with the License.

# You may obtain a copy of the License at

#

#     http://www.apache.org/licenses/LICENSE-2.0

#

# Unless required by applicable law or agreed to in writing, software

# distributed under the License is distributed on an "AS IS" BASIS,

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

# See the License for the specific language governing permissions and

# limitations under the License.

""" Tokenization classes for Transformer XL model.

    Adapted from https://github.com/kimiyoung/transformer-xl.

"""

from __future__ import (absolute_import, division, print_function,

                        unicode_literals)



import glob

import logging

import os

import sys

from collections import Counter, OrderedDict

from io import open



import numpy as np



from .file_utils import cached_path

from .tokenization_utils import PreTrainedTokenizer



try:

    import torch

except ImportError:

    pass



# pickle is used by get_lm_corpus below to read a pickled corpus cache
if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle





logger = logging.getLogger(__name__)



VOCAB_FILES_NAMES = {'pretrained_vocab_file': 'vocab.bin', 'vocab_file': 'vocab.txt'}



PRETRAINED_VOCAB_FILES_MAP = {

    'pretrained_vocab_file':

    {

        'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin",

    }

}



PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {

    'transfo-xl-wt103': None,

}



PRETRAINED_CORPUS_ARCHIVE_MAP = {

    'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-corpus.bin",

}

CORPUS_NAME = 'corpus.bin'



class TransfoXLTokenizer(PreTrainedTokenizer):

    """

    Transformer-XL tokenizer adapted from Vocab class in https://github.com/kimiyoung/transformer-xl

    """

    vocab_files_names = VOCAB_FILES_NAMES

    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP

    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES



    def __init__(self, special=None, min_freq=0, max_size=None, lower_case=False,

                 delimiter=None, vocab_file=None, pretrained_vocab_file=None,

                 never_split=None, unk_token="<unk>", eos_token="<eos>",

                 additional_special_tokens=["<formula>"], **kwargs):

        super(TransfoXLTokenizer, self).__init__(unk_token=unk_token, eos_token=eos_token,

                                                 additional_special_tokens=additional_special_tokens,

                                                 **kwargs)



        self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens

        self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens



        if never_split is None:

            never_split = self.all_special_tokens

        if special is None:

            special = []

        self.counter = Counter()

        self.special = special

        self.min_freq = min_freq

        self.max_size = max_size

        self.lower_case = lower_case

        self.delimiter = delimiter

        self.vocab_file = vocab_file

        self.never_split = never_split



        if pretrained_vocab_file is not None:

            # Hack because, honestly this tokenizer was not made to be used

            # in a library like ours, at all.

            vocab_dict = torch.load(pretrained_vocab_file)

            for key, value in vocab_dict.items():

                if key not in self.__dict__:

                    self.__dict__[key] = value



        if vocab_file is not None:

            self.build_vocab()



    def count_file(self, path, verbose=False, add_eos=False):

        if verbose: logger.info('counting file {} ...'.format(path))

        assert os.path.exists(path)



        sents = []

        with open(path, 'r', encoding='utf-8') as f:

            for idx, line in enumerate(f):

                if verbose and idx > 0 and idx % 500000 == 0:

                    logger.info('    line {}'.format(idx))

                symbols = self.tokenize(line, add_eos=add_eos)

                self.counter.update(symbols)

                sents.append(symbols)



        return sents



    def count_sents(self, sents, verbose=False):

        """

            sents : a list of sentences, each a list of tokenized symbols

        """

        if verbose: logger.info('counting {} sents ...'.format(len(sents)))

        for idx, symbols in enumerate(sents):

            if verbose and idx > 0 and idx % 500000 == 0:

                logger.info('    line {}'.format(idx))

            self.counter.update(symbols)



    def _build_from_file(self, vocab_file):

        self.idx2sym = []

        self.sym2idx = OrderedDict()



        with open(vocab_file, 'r', encoding='utf-8') as f:
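            # each line lists a token (optionally followed by its count); only the token is kept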

            for line in f:

                symb = line.strip().split()[0]

                self.add_symbol(symb)

        if '<UNK>' in self.sym2idx:

            self.unk_idx = self.sym2idx['<UNK>']

        elif '<unk>' in self.sym2idx:

            self.unk_idx = self.sym2idx['<unk>']

        else:

            raise ValueError('No <unk> token in vocabulary')



    def save_vocabulary(self, vocab_path):

        """Save the tokenizer vocabulary to a directory or file."""

        if os.path.isdir(vocab_path):
            vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['pretrained_vocab_file'])
        else:
            vocab_file = vocab_path
        torch.save(self.__dict__, vocab_file)

        return (vocab_file,)



    def build_vocab(self):

        if self.vocab_file:

            logger.info('building vocab from {}'.format(self.vocab_file))

            self._build_from_file(self.vocab_file)

            logger.info('final vocab size {}'.format(len(self)))

        else:

            logger.info('building vocab with min_freq={}, max_size={}'.format(

                self.min_freq, self.max_size))

            self.idx2sym = []

            self.sym2idx = OrderedDict()



            for sym in self.special:

                self.add_special(sym)



            for sym, cnt in self.counter.most_common(self.max_size):

                if cnt < self.min_freq: break

                self.add_symbol(sym)



            logger.info('final vocab size {} from {} unique tokens'.format(

                len(self), len(self.counter)))



    def encode_file(self, path, ordered=False, verbose=False, add_eos=True,

            add_double_eos=False):

        if verbose: logger.info('encoding file {} ...'.format(path))

        assert os.path.exists(path)

        encoded = []

        with open(path, 'r', encoding='utf-8') as f:

            for idx, line in enumerate(f):

                if verbose and idx > 0 and idx % 500000 == 0:

                    logger.info('    line {}'.format(idx))

                symbols = self.tokenize(line, add_eos=add_eos,

                    add_double_eos=add_double_eos)

                encoded.append(self.convert_to_tensor(symbols))



        if ordered:

            encoded = torch.cat(encoded)



        return encoded



    def encode_sents(self, sents, ordered=False, verbose=False):

        if verbose: logger.info('encoding {} sents ...'.format(len(sents)))

        encoded = []

        for idx, symbols in enumerate(sents):

            if verbose and idx > 0 and idx % 500000 == 0:

                logger.info('    line {}'.format(idx))

            encoded.append(self.convert_to_tensor(symbols))



        if ordered:

            encoded = torch.cat(encoded)



        return encoded



    def add_special(self, sym):

        if sym not in self.sym2idx:

            self.idx2sym.append(sym)

            self.sym2idx[sym] = len(self.idx2sym) - 1

            setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym])



    def add_symbol(self, sym):

        if sym not in self.sym2idx:

            self.idx2sym.append(sym)

            self.sym2idx[sym] = len(self.idx2sym) - 1



    def _convert_id_to_token(self, idx):

        """Converts an id in a token (BPE) using the vocab."""

        assert 0 <= idx < len(self), 'Index {} out of vocabulary range'.format(idx)

        return self.idx2sym[idx]



    def _convert_token_to_id(self, sym):

        """ Converts a token (str/unicode) in an id using the vocab. """

        if sym in self.sym2idx:

            return self.sym2idx[sym]

        else:

            # logger.info('encounter unk {}'.format(sym))

            # assert '<eos>' not in sym

            if hasattr(self, 'unk_idx'):

                return self.sym2idx.get(sym, self.unk_idx)

            # Backward compatibility with pre-trained models

            elif '<unk>' in self.sym2idx:

                return self.sym2idx['<unk>']

            elif '<UNK>' in self.sym2idx:

                return self.sym2idx['<UNK>']

            else:

                raise ValueError('Token not in vocabulary and no <unk> token in vocabulary for replacement')



    def convert_tokens_to_string(self, tokens):

        """ Converts a sequence of tokens (string) in a single string. """

        out_string = ' '.join(tokens).strip()

        return out_string



    def convert_to_tensor(self, symbols):

        return torch.LongTensor(self.convert_tokens_to_ids(symbols))



    @property

    def vocab_size(self):

        return len(self.idx2sym)



    def _tokenize(self, line, add_eos=False, add_double_eos=False):

        line = line.strip()

        # convert to lower case

        if self.lower_case:

            line = line.lower()



        # empty delimiter '' will evaluate False

        if self.delimiter == '':

            symbols = line

        else:

            symbols = line.split(self.delimiter)



        if add_double_eos: # lm1b

            return ['<S>'] + symbols + ['<S>']

        elif add_eos:

            return symbols + ['<eos>']

        else:

            return symbols
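
# Example usage of TransfoXLTokenizer (a minimal sketch; 'vocab.txt' is a
# hypothetical plain-text vocabulary with one token per line, including an
# '<unk>' entry, as read by _build_from_file):
#
#     tokenizer = TransfoXLTokenizer(vocab_file='vocab.txt', lower_case=True)
#     tokens = tokenizer.tokenize("the quick brown fox", add_eos=True)
#     # -> ['the', 'quick', 'brown', 'fox', '<eos>']
#     ids = tokenizer.convert_tokens_to_ids(tokens)
#     text = tokenizer.convert_tokens_to_string(tokens)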





class LMOrderedIterator(object):

    def __init__(self, data, bsz, bptt, device='cpu', ext_len=None):

        """

            data -- LongTensor -- the LongTensor is strictly ordered

        """

        self.bsz = bsz

        self.bptt = bptt

        self.ext_len = ext_len if ext_len is not None else 0



        self.device = device



        # Work out how cleanly we can divide the dataset into bsz parts.

        self.n_step = data.size(0) // bsz



        # Trim off any extra elements that wouldn't cleanly fit (remainders).

        data = data.narrow(0, 0, self.n_step * bsz)



        # Evenly divide the data across the bsz batches.

        self.data = data.view(bsz, -1).t().contiguous().to(device)



        # Number of mini-batches

        self.n_batch = (self.n_step + self.bptt - 1) // self.bptt



    def get_batch(self, i, bptt=None):

        if bptt is None: bptt = self.bptt

        seq_len = min(bptt, self.data.size(0) - 1 - i)



        end_idx = i + seq_len

        beg_idx = max(0, i - self.ext_len)
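        # The slice [beg_idx:end_idx] can reach back up to ext_len tokens before
        # position i as extra context; targets cover only the seq_len positions
        # starting at i+1.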



        data = self.data[beg_idx:end_idx]

        target = self.data[i+1:i+1+seq_len]



        data_out = data.transpose(0, 1).contiguous().to(self.device)

        target_out = target.transpose(0, 1).contiguous().to(self.device)



        return data_out, target_out, seq_len



    def get_fixlen_iter(self, start=0):

        for i in range(start, self.data.size(0) - 1, self.bptt):

            yield self.get_batch(i)



    def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3):

        max_len = self.bptt + max_deviation * std

        i = start

        while True:

            bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2.

            bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std))))

            data, target, seq_len = self.get_batch(i, bptt)

            i += seq_len

            yield data, target, seq_len

            if i >= self.data.size(0) - 2:

                break



    def __iter__(self):

        return self.get_fixlen_iter()
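
# Example usage of LMOrderedIterator (a minimal sketch; `ids` is assumed to be a
# 1-D LongTensor of token ids, e.g. from TransfoXLTokenizer.encode_file(..., ordered=True)):
#
#     train_iter = LMOrderedIterator(ids, bsz=4, bptt=32)
#     for data, target, seq_len in train_iter:
#         # data, target: [bsz x seq_len]; target is data shifted one token ahead
#         pass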





class LMShuffledIterator(object):

    def __init__(self, data, bsz, bptt, device='cpu', ext_len=None, shuffle=False):

        """

            data -- list[LongTensor] -- there is no order among the LongTensors

        """

        self.data = data



        self.bsz = bsz

        self.bptt = bptt

        self.ext_len = ext_len if ext_len is not None else 0



        self.device = device

        self.shuffle = shuffle



    def get_sent_stream(self):

        # index iterator

        epoch_indices = np.random.permutation(len(self.data)) if self.shuffle \
            else np.array(range(len(self.data)))



        # sentence iterator

        for idx in epoch_indices:

            yield self.data[idx]



    def stream_iterator(self, sent_stream):

        # streams for each data in the batch

        streams = [None] * self.bsz



        data = torch.LongTensor(self.bptt, self.bsz)

        target = torch.LongTensor(self.bptt, self.bsz)



        n_retain = 0



        while True:

            # data   : [n_retain+bptt x bsz]

            # target : [bptt x bsz]

            data[n_retain:].fill_(-1)

            target.fill_(-1)



            valid_batch = True



            for i in range(self.bsz):

                n_filled = 0

                try:
                    while n_filled < self.bptt:
                        if streams[i] is None or len(streams[i]) <= 1:
                            streams[i] = next(sent_stream)
                        # number of new tokens to fill in
                        n_new = min(len(streams[i]) - 1, self.bptt - n_filled)
                        # first n_retain tokens are retained from last batch
                        data[n_retain+n_filled:n_retain+n_filled+n_new, i] = \
                            streams[i][:n_new]
                        target[n_filled:n_filled+n_new, i] = \
                            streams[i][1:n_new+1]
                        streams[i] = streams[i][n_new:]

                        n_filled += n_new

                except StopIteration:

                    valid_batch = False

                    break



            if not valid_batch:

                return



            data_out = data.transpose(0, 1).contiguous().to(self.device)

            target_out = target.transpose(0, 1).contiguous().to(self.device)



            yield data_out, target_out, self.bptt
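
            # After yielding, retain the last ext_len rows of `data` as extended
            # context for the next batch, then resize the buffer to n_retain + bptt rows.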



            n_retain = min(data.size(0), self.ext_len)

            if n_retain > 0:

                data[:n_retain] = data[-n_retain:]

            data.resize_(n_retain + self.bptt, data.size(1))



    def __iter__(self):

        # sent_stream is an iterator

        sent_stream = self.get_sent_stream()



        for batch in self.stream_iterator(sent_stream):

            yield batch





class LMMultiFileIterator(LMShuffledIterator):

    def __init__(self, paths, vocab, bsz, bptt, device='cpu', ext_len=None,

        shuffle=False):



        self.paths = paths

        self.vocab = vocab



        self.bsz = bsz

        self.bptt = bptt

        self.ext_len = ext_len if ext_len is not None else 0



        self.device = device

        self.shuffle = shuffle



    def get_sent_stream(self, path):

        sents = self.vocab.encode_file(path, add_double_eos=True)

        if self.shuffle:

            np.random.shuffle(sents)

        sent_stream = iter(sents)



        return sent_stream



    def __iter__(self):

        if self.shuffle:

            np.random.shuffle(self.paths)



        for path in self.paths:

            # sent_stream is an iterator

            sent_stream = self.get_sent_stream(path)

            for batch in self.stream_iterator(sent_stream):

                yield batch





class TransfoXLCorpus(object):

    @classmethod

    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):

        """

        Instantiate a pre-processed corpus.

        """

        vocab = TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)

        if pretrained_model_name_or_path in PRETRAINED_CORPUS_ARCHIVE_MAP:

            corpus_file = PRETRAINED_CORPUS_ARCHIVE_MAP[pretrained_model_name_or_path]

        else:

            corpus_file = os.path.join(pretrained_model_name_or_path, CORPUS_NAME)

        # redirect to the cache, if necessary

        try:

            resolved_corpus_file = cached_path(corpus_file, cache_dir=cache_dir)

        except EnvironmentError:

            logger.error(

                "Corpus '{}' was not found in corpus list ({}). "

                "We assumed '{}' was a path or url but couldn't find files {} "

                "at this path or url.".format(

                    pretrained_model_name_or_path,

                    ', '.join(PRETRAINED_CORPUS_ARCHIVE_MAP.keys()),

                    pretrained_model_name_or_path,

                    corpus_file))

            return None

        if resolved_corpus_file == corpus_file:

            logger.info("loading corpus file {}".format(corpus_file))

        else:

            logger.info("loading corpus file {} from cache at {}".format(

                corpus_file, resolved_corpus_file))



        # Instantiate the corpus and overwrite its attributes with the pre-processed data.

        corpus = cls(*inputs, **kwargs)

        corpus_dict = torch.load(resolved_corpus_file)

        for key, value in corpus_dict.items():

            corpus.__dict__[key] = value

        corpus.vocab = vocab

        if corpus.train is not None:

            corpus.train = torch.tensor(corpus.train, dtype=torch.long)

        if corpus.valid is not None:

            corpus.valid = torch.tensor(corpus.valid, dtype=torch.long)

        if corpus.test is not None:

            corpus.test = torch.tensor(corpus.test, dtype=torch.long)

        return corpus



    def __init__(self, *args, **kwargs):

        self.vocab = TransfoXLTokenizer(*args, **kwargs)

        self.dataset = None

        self.train = None

        self.valid = None

        self.test = None



    def build_corpus(self, path, dataset):

        self.dataset = dataset



        if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:

            self.vocab.count_file(os.path.join(path, 'train.txt'))

            self.vocab.count_file(os.path.join(path, 'valid.txt'))

            self.vocab.count_file(os.path.join(path, 'test.txt'))

        elif self.dataset == 'wt103':

            self.vocab.count_file(os.path.join(path, 'train.txt'))

        elif self.dataset == 'lm1b':

            train_path_pattern = os.path.join(

                path, '1-billion-word-language-modeling-benchmark-r13output',

                'training-monolingual.tokenized.shuffled', 'news.en-*')

            train_paths = glob.glob(train_path_pattern)

            # the vocab will load from file when build_vocab() is called



        self.vocab.build_vocab()



        if self.dataset in ['ptb', 'wt2', 'wt103']:

            self.train = self.vocab.encode_file(

                os.path.join(path, 'train.txt'), ordered=True)

            self.valid = self.vocab.encode_file(

                os.path.join(path, 'valid.txt'), ordered=True)

            self.test = self.vocab.encode_file(

                os.path.join(path, 'test.txt'), ordered=True)

        elif self.dataset in ['enwik8', 'text8']:

            self.train = self.vocab.encode_file(

                os.path.join(path, 'train.txt'), ordered=True, add_eos=False)

            self.valid = self.vocab.encode_file(

                os.path.join(path, 'valid.txt'), ordered=True, add_eos=False)

            self.test = self.vocab.encode_file(

                os.path.join(path, 'test.txt'), ordered=True, add_eos=False)

        elif self.dataset == 'lm1b':

            self.train = train_paths

            self.valid = self.vocab.encode_file(

                os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True)

            self.test = self.vocab.encode_file(

                os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True)



    def get_iterator(self, split, *args, **kwargs):

        if split == 'train':

            if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:

                data_iter = LMOrderedIterator(self.train, *args, **kwargs)

            elif self.dataset == 'lm1b':

                kwargs['shuffle'] = True

                data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs)

        elif split in ['valid', 'test']:

            data = self.valid if split == 'valid' else self.test

            if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:

                data_iter = LMOrderedIterator(data, *args, **kwargs)

            elif self.dataset == 'lm1b':

                data_iter = LMShuffledIterator(data, *args, **kwargs)



        return data_iter
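
# Example usage of TransfoXLCorpus (a minimal sketch; downloads the pre-processed
# WikiText-103 corpus listed in PRETRAINED_CORPUS_ARCHIVE_MAP):
#
#     corpus = TransfoXLCorpus.from_pretrained('transfo-xl-wt103')
#     va_iter = corpus.get_iterator('valid', bsz=10, bptt=128, device='cpu')
#     for data, target, seq_len in va_iter:
#         pass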





def get_lm_corpus(datadir, dataset):

    fn = os.path.join(datadir, 'cache.pt')

    fn_pickle = os.path.join(datadir, 'cache.pkl')

    if os.path.exists(fn):
        logger.info('Loading cached dataset...')
        corpus = torch.load(fn)
    elif os.path.exists(fn_pickle):
        logger.info('Loading cached dataset from pickle...')
        with open(fn_pickle, "rb") as fp:
            corpus = pickle.load(fp)

    else:

        logger.info('Producing dataset {}...'.format(dataset))

        kwargs = {}

        if dataset in ['wt103', 'wt2']:

            kwargs['special'] = ['<eos>']

            kwargs['lower_case'] = False

        elif dataset == 'ptb':

            kwargs['special'] = ['<eos>']

            kwargs['lower_case'] = True

        elif dataset == 'lm1b':

            kwargs['special'] = []

            kwargs['lower_case'] = False

            kwargs['vocab_file'] = os.path.join(datadir, '1b_word_vocab.txt')

        elif dataset in ['enwik8', 'text8']:

            pass



        corpus = TransfoXLCorpus(**kwargs)
        corpus.build_corpus(datadir, dataset)
        torch.save(corpus, fn)



    return corpus
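
# Example usage of get_lm_corpus (a minimal sketch; 'data/wikitext-103' is a
# hypothetical directory containing train.txt / valid.txt / test.txt):
#
#     corpus = get_lm_corpus('data/wikitext-103', 'wt103')
#     tr_iter = corpus.get_iterator('train', bsz=32, bptt=128)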

